# ACUPUNCTURE CATE ANALYSIS - SAMPLE SIZE ANALYSIS WITH THEORETICAL BOUNDS

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
import warnings
import sys
from datetime import datetime
warnings.filterwarnings('ignore')

def theoretical_bound(m, beta, N, delta, OPT):
    """Return the theoretical bound ``(1 - (N*ln(2N/delta)/m)**beta) * OPT``.

    Args:
        m: Number of samples drawn (scalar).
        beta: Exponent applied to the error term ``N*ln(2N/delta)/m``.
        N: Number of groups.
        delta: Value used inside the ``ln(2N/delta)`` term.
        OPT: Optimal allocation value that the bound is scaled by.

    Returns:
        The bound as a float. ``0.0`` is returned when the bound is
        meaningless: either ``m`` is not positive, or the error term is
        at least 1.
    """
    if m <= 0:
        # Without samples the error term is infinite (m == 0) or negative
        # (m < 0, which would turn a fractional power into NaN).
        return 0.0
    term = N * np.log(2 * N / delta) / m
    if term >= 1:
        return 0.0  # Bound becomes meaningless
    return (1 - term**beta) * OPT

class TeeOutput:
    """Write output to both the console and a log file simultaneously.

    The instance captures ``sys.stdout`` as it is at construction time and
    forwards every ``write`` to it and to the log file. It can be used as a
    context manager, in which case the log file is closed on exit.
    """

    def __init__(self, filename):
        """Open ``filename`` for writing and remember the current stdout."""
        self.terminal = sys.stdout
        # Explicit UTF-8: the analysis prints non-ASCII characters such as
        # "→", which a locale-dependent default encoding may fail to encode.
        self.log = open(filename, 'w', encoding='utf-8')

    def write(self, message):
        """Write ``message`` to the console and, while it is open, the log."""
        self.terminal.write(message)
        if not self.log.closed:
            self.log.write(message)
            self.log.flush()

    def flush(self):
        """Flush the console stream and, while it is open, the log file."""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; console output keeps working afterwards."""
        self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions from the with-body

class AcupunctureSampleSizeAnalyzer:
    """Acupuncture CATE allocation with sample size analysis and theoretical bounds."""

    def __init__(self, random_seed=42):
        self.random_seed = random_seed
        np.random.seed(random_seed)
        print(f"Acupuncture Sample Size Analyzer initialized with seed {random_seed}")

    def process_acupuncture_data(self, df, outcome_col='pk5', treatment_col='group'):
        """Process acupuncture dataset - SAME AS ORIGINAL ACUPUNCTURE CODE."""
        print(f"Processing acupuncture data with {len(df)} patients")
        print(f"Available columns: {list(df.columns)}")

        df_processed = df.copy()

        if treatment_col not in df_processed.columns:
            raise ValueError(f"Missing required treatment column: {treatment_col}")
        if outcome_col not in df_processed.columns:
            raise ValueError(f"Missing required outcome column: {outcome_col}")

        df_processed['treatment'] = df_processed[treatment_col]
        df_processed['outcome'] = df_processed[outcome_col]

        if 'pk1' in df_processed.columns:
            df_processed['baseline_headache'] = df_processed['pk1']
        else:
            df_processed['baseline_headache'] = 0

        initial_size = len(df_processed)
        df_processed = df_processed.dropna(subset=['outcome', 'treatment'])
        final_size = len(df_processed)

        if initial_size != final_size:
            print(f"Dropped {initial_size - final_size} rows due to missing outcome/treatment")

        print(f"Final dataset: {final_size} patients")
        print(f"Treatment distribution: {df_processed['treatment'].value_counts().to_dict()}")
        print(f"Outcome (12-month headache score) statistics: mean={df_processed['outcome'].mean():.2f}, std={df_processed['outcome'].std():.2f}")

        if 'baseline_headache' in df_processed.columns:
            print(f"Baseline headache stats: mean={df_processed['baseline_headache'].mean():.2f}, std={df_processed['baseline_headache'].std():.2f}")

        return df_processed

    def create_age_chronicity_groups(self, df, n_groups=30, min_size=6):
        """Create age-chronicity groups - SAME AS ACUPUNCTURE CODE."""
        print(f"Creating age-chronicity interaction groups (target: {n_groups})")

        if 'age' not in df.columns or 'chronicity' not in df.columns:
            print("No age or chronicity variables found")
            return []

        age = df['age'].fillna(df['age'].median())
        chronicity = df['chronicity'].fillna(df['chronicity'].median())

        age_norm = (age - age.min()) / (age.max() - age.min()) if age.max() > age.min() else age * 0
        chron_norm = (chronicity - chronicity.min()) / (chronicity.max() - chronicity.min()) if chronicity.max() > chronicity.min() else chronicity * 0
        interaction_score = age_norm * chron_norm

        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(interaction_score, percentiles)
        bins = np.digitize(interaction_score, cuts) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'age_chronicity_group_{i}',
                    'indices': indices,
                    'type': 'age_chronicity'
                })

        print(f"Created {len(groups)} age-chronicity interaction groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_age_groups(self, df, n_groups=30, min_size=6):
        """Create age groups - SAME AS ACUPUNCTURE CODE."""
        print(f"Creating age groups (target: {n_groups})")

        if 'age' not in df.columns:
            print("No age variable found")
            return []

        age = df['age'].fillna(df['age'].median())
        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(age, percentiles)
        bins = np.digitize(age, cuts) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'age_group_{i}',
                    'indices': indices,
                    'type': 'age'
                })

        print(f"Created {len(groups)} age groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_chronicity_groups(self, df, n_groups=30, min_size=6):
        """Create chronicity groups - SAME AS ACUPUNCTURE CODE."""
        print(f"Creating chronicity groups (target: {n_groups})")

        if 'chronicity' not in df.columns:
            print("No chronicity variable found")
            return []

        chronicity = df['chronicity'].fillna(df['chronicity'].median())
        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(chronicity, percentiles)
        bins = np.digitize(chronicity, cuts) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'chronicity_group_{i}',
                    'indices': indices,
                    'type': 'chronicity'
                })

        print(f"Created {len(groups)} chronicity groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_baseline_headache_groups(self, df, n_groups=30, min_size=6):
        """Create baseline headache groups - SAME AS ACUPUNCTURE CODE."""
        print(f"Creating baseline headache groups (target: {n_groups})")

        if 'pk1' not in df.columns:
            print("No baseline headache data available")
            return []

        baseline = df['pk1'].fillna(df['pk1'].median())
        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(baseline, percentiles)
        bins = np.digitize(baseline, cuts) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'baseline_headache_{i}',
                    'indices': indices,
                    'type': 'baseline_headache'
                })

        print(f"Created {len(groups)} baseline headache groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_multidimensional_groups(self, df, n_groups=30, min_size=6):
        """Create multidimensional groups - SAME AS ACUPUNCTURE CODE."""
        print(f"Creating multidimensional composite groups (target: {n_groups})")

        continuous_vars = ['age', 'chronicity', 'pk1']
        available_vars = [col for col in continuous_vars if col in df.columns]

        if len(available_vars) < 2:
            print("Not enough continuous variables for multidimensional grouping")
            return []

        print(f"Using continuous variables: {available_vars}")

        composite_score = pd.Series(0.0, index=df.index)
        for var in available_vars:
            values = df[var].fillna(df[var].median())
            if values.max() > values.min():
                normalized = (values - values.min()) / (values.max() - values.min())
            else:
                normalized = values * 0
            composite_score += normalized

        composite_score = composite_score / len(available_vars)

        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(composite_score, percentiles)
        bins = np.digitize(composite_score, cuts) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'multidim_group_{i}',
                    'indices': indices,
                    'type': 'multidimensional'
                })

        print(f"Created {len(groups)} multidimensional groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_covariate_forest_groups(self, df, n_groups=30, min_size=6):
        """Create covariate forest groups - SAME AS ACUPUNCTURE CODE."""
        print(f"Creating covariate-based forest groups (target: {n_groups})")

        feature_cols = ['age', 'sex', 'migraine', 'chronicity', 'pk1']
        available_features = [col for col in feature_cols if col in df.columns]

        if not available_features:
            print("No features available for covariate clustering")
            return []

        X = df[available_features].copy()

        for col in X.columns:
            if X[col].dtype == 'object':
                le = LabelEncoder()
                X[col] = X[col].fillna('missing')
                X[col] = le.fit_transform(X[col])
            else:
                if X[col].isna().any():
                    X[col] = X[col].fillna(X[col].median())

        cluster_features = StandardScaler().fit_transform(X.values)
        labels = KMeans(n_clusters=n_groups, random_state=self.random_seed).fit_predict(cluster_features)

        groups = []
        for i in range(n_groups):
            indices = df.index[labels == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'covariate_cluster_{i}',
                    'indices': indices,
                    'type': 'covariate_cluster'
                })

        print(f"Created {len(groups)} covariate-based groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def _ensure_balance_and_compute_cate(self, df, groups):
        """Ensure balance and compute CATE - SAME AS ACUPUNCTURE CODE."""
        balanced_groups = []

        for group in groups:
            group_df = df.loc[group['indices']]

            treatment_rate = group_df['treatment'].mean()
            n_treated = group_df['treatment'].sum()
            n_control = len(group_df) - n_treated

            if not (0.15 <= treatment_rate <= 0.85 and n_treated >= 3 and n_control >= 3):
                continue

            treated_outcomes = group_df[group_df['treatment'] == 1]['outcome']
            control_outcomes = group_df[group_df['treatment'] == 0]['outcome']
            # Reverse sign: lower headache scores = better (treatment benefit)
            cate = -(treated_outcomes.mean() - control_outcomes.mean())

            balanced_groups.append({
                'id': group['id'],
                'indices': group['indices'],
                'size': len(group_df),
                'treatment_rate': treatment_rate,
                'n_treated': int(n_treated),
                'n_control': int(n_control),
                'cate': cate,
                'type': group['type']
            })

        return balanced_groups

    def normalize_cates(self, groups):
        """Min-max scale every group's ``cate`` into ``normalized_cate``.

        The group dicts are modified in place. When all CATEs are equal,
        every group gets ``normalized_cate = 0.5``. An empty list is returned
        unchanged instead of failing on ``min()``/``max()`` of an empty
        sequence.

        Args:
            groups: List of group dicts, each with a ``cate`` key.

        Returns:
            The same ``groups`` list, with ``normalized_cate`` in [0, 1]
            added to every dict.
        """
        if not groups:
            print("CATE normalization skipped: no groups")
            return groups

        cates = [g['cate'] for g in groups]
        min_cate, max_cate = min(cates), max(cates)

        if max_cate > min_cate:
            for group in groups:
                group['normalized_cate'] = (group['cate'] - min_cate) / (max_cate - min_cate)
        else:
            # No spread between groups: place them all at the midpoint.
            for group in groups:
                group['normalized_cate'] = 0.5

        print(f"CATE normalization: [{min_cate:.3f}, {max_cate:.3f}] → [0, 1]")
        return groups

    def simulate_sampling_trial(self, groups, sample_size, trial_seed):
        """Simulate one trial of uniform group sampling with Bernoulli outcomes.

        NumPy's global RNG is reseeded with ``self.random_seed + trial_seed``.
        Then, ``sample_size`` times, a group is picked uniformly at random and
        a Bernoulli outcome is drawn with that group's ``normalized_cate`` as
        the success probability.

        Args:
            groups: List of group dicts with a ``normalized_cate`` key.
            sample_size: Total number of draws across all groups.
            trial_seed: Offset added to the analyzer's base seed.

        Returns:
            Tuple ``(tau_estimates, sample_counts)`` of float arrays, one
            entry per group: the mean of the group's draws (0 if it was never
            drawn) and the number of draws it received.
        """
        np.random.seed(self.random_seed + trial_seed)

        success_prob = np.array([g['normalized_cate'] for g in groups])
        group_total = len(groups)
        draws_per_group = np.zeros(group_total)
        running_mean = np.zeros(group_total)

        for _ in range(sample_size):
            pick = np.random.randint(group_total)
            outcome = np.random.binomial(1, success_prob[pick])

            draws_per_group[pick] += 1
            k = draws_per_group[pick]
            # Running-mean update. For the first draw (k == 1) the previous
            # estimate is 0, so this reduces exactly to `outcome`.
            running_mean[pick] = ((k - 1) * running_mean[pick] + outcome) / k

        # Groups with no samples get estimate 0
        running_mean[draws_per_group == 0] = 0

        return running_mean, draws_per_group

    def analyze_sample_size_performance(self, groups, sample_sizes, budget_percentages, n_trials=50):
        """Measure realized allocation value as a function of sample size.

        For every sample size, ``n_trials`` sampling trials are simulated
        (trial index used as the seed offset). In each trial, and for each
        budget, the ``K`` groups with the highest estimated effect are
        selected and scored by the sum of their true ``normalized_cate``.
        ``K`` is ``max(1, int(fraction * number_of_groups))``.

        Args:
            groups: List of group dicts with a ``normalized_cate`` key.
            sample_sizes: Iterable of total sample sizes to simulate.
            budget_percentages: Budget fractions; they key the result dicts.
            n_trials: Number of simulated trials per sample size.

        Returns:
            Tuple ``(results, optimal_values)``. ``results[fraction]`` holds
            parallel lists ``sample_sizes``, ``values`` (mean realized value
            over trials) and ``stds`` (its standard deviation).
            ``optimal_values[fraction]`` is the sum of the ``K`` largest true
            values.
        """
        print(f"Analyzing sample size performance with {len(groups)} groups")

        group_count = len(groups)
        true_tau = np.array([g['normalized_cate'] for g in groups])

        # Number of groups selected under each budget fraction.
        top_k = [max(1, int(fraction * group_count)) for fraction in budget_percentages]
        print(f"Budget percentages {budget_percentages} → K values {top_k}")

        # Best achievable value per budget: sum of the K largest true effects.
        true_ranking = np.argsort(true_tau)
        optimal_values = {
            fraction: np.sum(true_tau[true_ranking[-k:]])
            for fraction, k in zip(budget_percentages, top_k)
        }

        results = {
            fraction: {'sample_sizes': [], 'values': [], 'stds': []}
            for fraction in budget_percentages
        }

        for sample_size in sample_sizes:
            print(f"  Sample size {sample_size}...")

            realized = {fraction: [] for fraction in budget_percentages}

            for trial in range(n_trials):
                estimates, _ = self.simulate_sampling_trial(groups, sample_size, trial)
                estimated_ranking = np.argsort(estimates)

                # Select by the estimates, score by the true values.
                for fraction, k in zip(budget_percentages, top_k):
                    chosen = estimated_ranking[-k:]
                    realized[fraction].append(np.sum(true_tau[chosen]))

            for fraction in budget_percentages:
                record = results[fraction]
                record['sample_sizes'].append(sample_size)
                record['values'].append(np.mean(realized[fraction]))
                record['stds'].append(np.std(realized[fraction]))

        return results, optimal_values

    def plot_sample_size_analysis(self, results, optimal_values, method_name, budget_percentages, n_groups,
                                  delta=0.05):
        """Plot one panel per budget: empirical value vs. sample size with reference curves.

        Every panel shows the mean realized value (with standard-deviation
        error bars) normalized by the optimal value, a horizontal line at 1.0,
        and two ``theoretical_bound`` curves with exponents 0.5 and 1.0. The
        figure is saved as a PDF in the working directory, shown, and then
        closed.

        The grid has 3 columns and as many rows as needed (at least 2), so
        more than six budgets no longer overflow a fixed 2x3 grid; unused
        panels are hidden.

        Args:
            results: Per-budget dict with ``sample_sizes``, ``values`` and
                ``stds`` lists, keyed by budget fraction.
            optimal_values: Optimal value per budget fraction.
            method_name: Grouping method name, used in the title and filename.
            budget_percentages: Budget fractions to plot, one panel each.
            n_groups: Number of groups (``N`` in the bound and the title).
            delta: Value passed as ``delta`` to ``theoretical_bound``.
        """
        n_panels = len(budget_percentages)
        n_cols = 3
        # Keep at least two rows so the default six-budget layout is unchanged.
        n_rows = max(2, int(np.ceil(n_panels / n_cols)))
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows))
        axes = axes.flatten()

        print(f"\nPlotting {method_name} (N={n_groups})")
        print("="*60)

        for ax, bp in zip(axes, budget_percentages):
            # Get data for this budget
            sample_sizes = results[bp]['sample_sizes']
            values = results[bp]['values']
            stds = results[bp]['stds']
            optimal_val = optimal_values[bp]

            # Normalize all values by optimal value
            values_norm = np.array(values) / optimal_val
            stds_norm = np.array(stds) / optimal_val

            # Plot empirical performance curve
            ax.errorbar(sample_sizes, values_norm, yerr=stds_norm,
                      marker='o', capsize=5, capthick=3, linewidth=6, markersize=8,
                      label='Empirical data', color='blue', alpha=0.8)

            # Plot optimal value (normalized to 1)
            ax.axhline(y=1.0, color='black', linestyle=':', linewidth=2,
                      label='Optimal (1.0)', alpha=0.8)

            m_smooth = np.linspace(min(sample_sizes), max(sample_sizes), 200)

            # Reference curves: the bound with exponent 0.5 and with exponent 1.0
            ref_curve_05 = [theoretical_bound(m, 0.5, n_groups, delta, optimal_val) / optimal_val for m in m_smooth]
            ref_curve_10 = [theoretical_bound(m, 1.0, n_groups, delta, optimal_val) / optimal_val for m in m_smooth]

            ax.plot(m_smooth, ref_curve_05, 'red', linestyle=(0, (3, 2)), linewidth=6,
                  label='FullCATE', alpha=0.8)
            ax.plot(m_smooth, ref_curve_10, 'green', linestyle=(0, (3, 1, 1, 1)), linewidth=6,
                  label='ALLOC', alpha=0.8)

            # Set labels
            ax.set_xlabel('Sample size', fontsize=23)
            ax.set_ylabel('Normalized allocation value', fontsize=23)
            ax.set_title(f'Budget = {bp*100:.0f}% (K={max(1, int(bp * n_groups))})', fontsize=24, fontweight='bold')

            # Larger legend but keep it in corner
            ax.legend(fontsize=21, framealpha=0.9)
            ax.grid(True, alpha=0.4, linewidth=1)

            # Make tick labels larger
            ax.tick_params(axis='both', which='major', labelsize=16, width=1.5, length=5)

            y_min = 0.2
            y_max = 1.05  # Slightly above optimal
            ax.set_ylim(y_min, y_max)

            for spine in ax.spines.values():
                spine.set_linewidth(1.5)

        # Hide panels that have no budget assigned to them.
        for unused_ax in axes[n_panels:]:
            unused_ax.set_visible(False)

        fig.suptitle(f'{method_name} (N={n_groups})', fontsize=24, fontweight='bold')
        fig.tight_layout()

        clean_name = method_name.replace(' ', '_').replace('(', '').replace(')', '').replace('-', '_')
        pdf_filename = f"{clean_name}_N{n_groups}_sample_size_analysis.pdf"
        fig.savefig(pdf_filename, format='pdf', dpi=300, bbox_inches='tight')
        print(f"Saved plot as: {pdf_filename}")

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

        print(f"Plot complete for {method_name}")

def run_acupuncture_sample_size_analysis(df_acupuncture, sample_size_range=None, budget_percentages=None, n_trials=50,
                                       outcome_col='pk5', treatment_col='group'):
    """Run the sample size analysis for every grouping method.

    For each grouping method a fresh analyzer is created, the data is
    processed, groups are built with ``n_groups=30`` and ``min_size=6``,
    CATEs are normalized, the simulation is run and the result is plotted.
    Methods yielding fewer than 10 groups are skipped. Any exception raised
    while handling a method is printed and that method is left out of the
    result.

    Args:
        df_acupuncture: Raw acupuncture DataFrame.
        sample_size_range: Sample sizes to simulate; a built-in list is used
            when ``None``.
        budget_percentages: Budget fractions; a built-in list is used when
            ``None``.
        n_trials: Number of simulated trials per sample size.
        outcome_col: Name of the outcome column.
        treatment_col: Name of the treatment column.

    Returns:
        Dict mapping method name to a dict with ``results``,
        ``optimal_values`` and ``n_groups``.
    """
    if sample_size_range is None:
        sample_size_range = [100, 250, 500, 750, 1000, 1200, 1500, 2000, 5000, 10000, 20000]
    if budget_percentages is None:
        budget_percentages = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9]

    banner = "="*80
    print("ACUPUNCTURE SAMPLE SIZE ANALYSIS - EMPIRICAL VS THEORETICAL BOUNDS")
    print(f"Sample sizes: {sample_size_range}")
    print(f"Budget percentages: {budget_percentages}")
    print(f"Trials per sample size: {n_trials}")
    print(banner)

    # Display name -> analyzer method that builds the groups.
    grouping_methods = (
        ('Age-Chronicity Interaction', 'create_age_chronicity_groups'),
        ('Age Groups', 'create_age_groups'),
        ('Chronicity Groups', 'create_chronicity_groups'),
        ('Multidimensional Composite', 'create_multidimensional_groups'),
        ('Baseline Headache', 'create_baseline_headache_groups'),
        ('Covariate Forest', 'create_covariate_forest_groups'),
    )

    all_results = {}

    for method_name, builder_name in grouping_methods:
        print(f"\n{'='*80}")
        print(f"ANALYZING ACUPUNCTURE METHOD: {method_name}")
        print(banner)

        try:
            analyzer = AcupunctureSampleSizeAnalyzer()
            df_processed = analyzer.process_acupuncture_data(
                df_acupuncture, outcome_col=outcome_col, treatment_col=treatment_col
            )

            build_groups = getattr(analyzer, builder_name)
            groups = build_groups(df_processed, n_groups=30, min_size=6)
            group_count = len(groups)

            if group_count < 10:
                print(f"Too few groups ({group_count}) for {method_name} - skipping")
                continue

            groups = analyzer.normalize_cates(groups)

            # Run sample size analysis
            results, optimal_values = analyzer.analyze_sample_size_performance(
                groups, sample_size_range, budget_percentages, n_trials
            )

            all_results[method_name] = {
                'results': results,
                'optimal_values': optimal_values,
                'n_groups': len(groups)
            }

            # Create plots with theoretical bounds
            print(f"Creating plots for {method_name}...")
            analyzer.plot_sample_size_analysis(
                results, optimal_values, method_name, budget_percentages, len(groups)
            )

            # Print summary
            print(f"\nSummary for {method_name}:")
            print(f"Number of groups: {len(groups)}")
            print("Optimal values by budget:")
            for bp in budget_percentages:
                print(f"  {bp*100:.0f}%: {optimal_values[bp]:.3f}")

        except Exception as e:
            # Best effort: report the failure and move on to the next method.
            print(f"Error with {method_name}: {e}")

    return all_results

# Example usage
if __name__ == "__main__":
    # Simulation settings for the example run.
    sample_sizes = [100, 250, 500, 750, 1000, 1200, 1500, 2000, 5000, 10000, 20000]
    budget_percentages = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9]

    # Load the acupuncture dataset from the working directory.
    df_acupuncture = pd.read_stata('acupuncture.dta')

    # Run sample size analysis with theoretical bounds
    results = run_acupuncture_sample_size_analysis(
        df_acupuncture, sample_size_range=sample_sizes, budget_percentages=budget_percentages, n_trials=50
    )